use std::rc::Rc;

use gpui::{App, Bounds, Hsla, PathBuilder, Pixels, SharedString, TextAlign, Window, fill, px};
use gpui_component_macros::IntoPlot;
use num_traits::{Num, ToPrimitive};

use crate::{
    ActiveTheme, PixelsExt,
    plot::{
        AXIS_GAP, AxisText, Grid, Plot, PlotAxis, origin_point,
        scale::{Scale, ScaleBand, ScaleLinear, Sealed},
    },
};

#[derive(IntoPlot)]
pub struct CandlestickChart<T, X, Y>
where
    T: 'static,
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    data: Vec<T>,
    x: Option<Rc<dyn Fn(&T) -> X>>,
    open: Option<Rc<dyn Fn(&T) -> Y>>,
    high: Option<Rc<dyn Fn(&T) -> Y>>,
    low: Option<Rc<dyn Fn(&T) -> Y>>,
    close: Option<Rc<dyn Fn(&T) -> Y>>,
    tick_margin: usize,
    body_width_ratio: f32,
}

impl<T, X, Y> CandlestickChart<T, X, Y>
where
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    pub fn new<I>(data: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        Self {
            data: data.into_iter().collect(),
            x: None,
            open: None,
            high: None,
            low: None,
            close: None,
            tick_margin: 1,
            body_width_ratio: 0.8,
        }
    }

    pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
        self.x = Some(Rc::new(x));
        self
    }

    pub fn open(mut self, open: impl Fn(&T) -> Y + 'static) -> Self {
        self.open = Some(Rc::new(open));
        self
    }

    pub fn high(mut self, high: impl Fn(&T) -> Y + 'static) -> Self {
        self.high = Some(Rc::new(high));
        self
    }

    pub fn low(mut self, low: impl Fn(&T) -> Y + 'static) -> Self {
        self.low = Some(Rc::new(low));
        self
    }

    pub fn close(mut self, close: impl Fn(&T) -> Y + 'static) -> Self {
        self.close = Some(Rc::new(close));
        self
    }

    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
        self.tick_margin = tick_margin;
        self
    }

    pub fn body_width_ratio(mut self, ratio: f32) -> Self {
        self.body_width_ratio = ratio;
        self
    }
}

impl<T, X, Y> Plot for CandlestickChart<T, X, Y>
where
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
        let (Some(x_fn), Some(open_fn), Some(high_fn), Some(low_fn), Some(close_fn)) = (
            self.x.as_ref(),
            self.open.as_ref(),
            self.high.as_ref(),
            self.low.as_ref(),
            self.close.as_ref(),
        ) else {
            return;
        };

        let width = bounds.size.width.as_f32();
        let height = bounds.size.height.as_f32() - AXIS_GAP;

        // X scale
        let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
            .padding_inner(0.4)
            .padding_outer(0.2);
        let band_width = x.band_width();

        // Y scale
        let all_values: Vec<Y> = self
            .data
            .iter()
            .flat_map(|d| vec![high_fn(d), low_fn(d), open_fn(d), close_fn(d)])
            .collect();
        let y = ScaleLinear::new(all_values, vec![height, 10.]);

        // Draw X axis
        let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
            if (i + 1) % self.tick_margin == 0 {
                x.tick(&x_fn(d)).map(|x_tick| {
                    AxisText::new(
                        x_fn(d).into(),
                        x_tick + band_width / 2.,
                        cx.theme().muted_foreground,
                    )
                    .align(TextAlign::Center)
                })
            } else {
                None
            }
        });

        PlotAxis::new()
            .x(height)
            .x_label(x_label)
            .stroke(cx.theme().border)
            .paint(&bounds, window, cx);

        // Draw grid
        Grid::new()
            .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
            .stroke(cx.theme().border)
            .dash_array(&[px(4.), px(2.)])
            .paint(&bounds, window);

        // Draw candlesticks
        let origin = bounds.origin;
        let x_fn = x_fn.clone();
        let open_fn = open_fn.clone();
        let high_fn = high_fn.clone();
        let low_fn = low_fn.clone();
        let close_fn = close_fn.clone();

        for d in &self.data {
            let x_tick = x.tick(&x_fn(d));
            let Some(x_tick) = x_tick else {
                continue;
            };

            // Get OHLC values for the current data point
            let open = open_fn(d);
            let high = high_fn(d);
            let low = low_fn(d);
            let close = close_fn(d);

            // Convert values to pixel coordinates
            let open_y = y.tick(&open);
            let high_y = y.tick(&high);
            let low_y = y.tick(&low);
            let close_y = y.tick(&close);

            let (Some(open_y), Some(high_y), Some(low_y), Some(close_y)) =
                (open_y, high_y, low_y, close_y)
            else {
                continue;
            };

            // Determine if bullish (close > open) or bearish (close < open)
            let is_bullish = close > open;
            let color: Hsla = if is_bullish {
                cx.theme().bullish
            } else {
                cx.theme().bearish
            };

            // Calculate candlestick body dimensions
            let center_x = x_tick + band_width / 2.;
            let body_width = band_width * self.body_width_ratio;
            let body_left = center_x - body_width / 2.;
            let body_right = center_x + body_width / 2.;

            // Draw wick (high to low line)
            let mut wick_builder = PathBuilder::stroke(px(1.));
            wick_builder.move_to(origin_point(px(center_x), px(high_y), origin));
            wick_builder.line_to(origin_point(px(center_x), px(low_y), origin));

            if let Ok(path) = wick_builder.build() {
                window.paint_path(path, color);
            }

            // Draw body (open to close rectangle)
            // For bullish: top is close, bottom is open
            // For bearish: top is open, bottom is close
            let (top, bottom) = if is_bullish {
                (close_y, open_y)
            } else {
                (open_y, close_y)
            };

            let body_bounds = Bounds::from_corners(
                origin_point(px(body_left), px(top), origin),
                origin_point(px(body_right), px(bottom), origin),
            );

            window.paint_quad(fill(body_bounds, color));
        }
    }
}
